{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 使用结构化生成进行带源高亮的 RAG \n",
    "_作者: [Aymeric Roucher](https://huggingface.co/m-ric)_\n",
    "\n",
    "**结构化生成是一种方法**，它强制 LLN 的输出遵循某些约束，例如遵循特定的模式。\n",
    "\n",
    "这有许多用例：\n",
    "\n",
    "- ✅ 输出一个具有特定键的字典\n",
    "- 📏 确保输出长度超过 N 个字符\n",
    "- ⚙️ 更一般地说，强制输出遵循特定的正则表达式模式以进行下游处理。\n",
    "- 💡 在检索增强生成（RAG）中突出显示支持答案的源\n",
    "\n",
    "在这个 notebook 中，我们特别演示了最后一个用例：\n",
    "\n",
    "**➡️ 我们构建了一个 RAG 系统，它不仅提供答案，还突出显示这个答案所基于的支持片段。**\n",
    "\n",
    "_如果你需要 RAG 的入门介绍，可以查看[这个其他的教程](advanced_rag)。_\n",
    "\n",
    "这个 notebook 首先展示了通过提示进行结构化生成的简单方法，并突出了其局限性，然后演示了受限解码以实现更高效的结构化生成。\n",
    "\n",
    "它利用了 HuggingFace 推理端点（示例展示了一个[无服务器](https://huggingface.co/docs/api-inference/quicktour)端点，但你可以直接将端点更改为[专用](https://huggingface.co/docs/inference-endpoints/en/guides/access)端点），然后还展示了一个使用[outlines](https://github.com/outlines-dev/outlines)，一个结构化文本生成库的本地流水线。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install pandas json huggingface_hub pydantic outlines accelerate -q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import json\n",
    "from huggingface_hub import InferenceClient\n",
    "\n",
    "pd.set_option(\"display.max_colwidth\", None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\" I hope you're having a great day! I just wanted to check in and see how things are\""
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "repo_id = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
    "\n",
    "llm_client = InferenceClient(model=repo_id, timeout=120)\n",
    "\n",
    "# Test your LLM client\n",
    "llm_client.text_generation(prompt=\"How are you today?\", max_new_tokens=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 提示模型\n",
    "\n",
    "为了从模型中获得结构化输出，你可以简单地用适当的指导原则提示一个足够强大的模型，并且大多数时候它应该能够直接工作。\n",
    "\n",
    "在这种情况下，我们希望 RAG 模型不仅生成答案，还生成一个置信度分数和一些源代码片段。\n",
    "我们希望将这些生成为一个 JSON 字典，然后可以轻松地解析它以进行下游处理（在这里，我们将只突出显示源代码片段）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "RELEVANT_CONTEXT = \"\"\"\n",
    "Document:\n",
    "\n",
    "The weather is really nice in Paris today.\n",
    "To define a stop sequence in Transformers, you should pass the stop_sequence argument in your pipeline or model.\n",
    "\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "RAG_PROMPT_TEMPLATE_JSON = \"\"\"\n",
    "Answer the user query based on the source documents.\n",
    "\n",
    "Here are the source documents: {context}\n",
    "\n",
    "\n",
    "You should provide your answer as a JSON blob, and also provide all relevant short source snippets from the documents on which you directly based your answer, and a confidence score as a float between 0 and 1.\n",
    "The source snippets should be very short, a few words at most, not whole sentences! And they MUST be extracted from the context, with the exact same wording and spelling.\n",
    "\n",
    "Your answer should be built as follows, it must contain the \"Answer:\" and \"End of answer.\" sequences.\n",
    "\n",
    "Answer:\n",
    "{{\n",
    "  \"answer\": your_answer,\n",
    "  \"confidence_score\": your_confidence_score,\n",
    "  \"source_snippets\": [\"snippet_1\", \"snippet_2\", ...]\n",
    "}}\n",
    "End of answer.\n",
    "\n",
    "Now begin!\n",
    "Here is the user question: {user_query}.\n",
    "Answer:\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "USER_QUERY = \"How can I define a stop sequence in Transformers?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Answer the user query based on the source documents.\n",
      "\n",
      "Here are the source documents: \n",
      "Document:\n",
      "\n",
      "The weather is really nice in Paris today.\n",
      "To define a stop sequence in Transformers, you should pass the stop_sequence argument in your pipeline or model.\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "You should provide your answer as a JSON blob, and also provide all relevant short source snippets from the documents on which you directly based your answer, and a confidence score as a float between 0 and 1.\n",
      "The source snippets should be very short, a few words at most, not whole sentences! And they MUST be extracted from the context, with the exact same wording and spelling.\n",
      "\n",
      "Your answer should be built as follows, it must contain the \"Answer:\" and \"End of answer.\" sequences.\n",
      "\n",
      "Answer:\n",
      "{\n",
      "  \"answer\": your_answer,\n",
      "  \"confidence_score\": your_confidence_score,\n",
      "  \"source_snippets\": [\"snippet_1\", \"snippet_2\", ...]\n",
      "}\n",
      "End of answer.\n",
      "\n",
      "Now begin!\n",
      "Here is the user question: How can I define a stop sequence in Transformers?.\n",
      "Answer:\n",
      "\n"
     ]
    }
   ],
   "source": [
    "prompt = RAG_PROMPT_TEMPLATE_JSON.format(\n",
    "    context=RELEVANT_CONTEXT, user_query=USER_QUERY\n",
    ")\n",
    "print(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"answer\": \"You should pass the stop_sequence argument in your pipeline or model.\",\n",
      "  \"confidence_score\": 0.9,\n",
      "  \"source_snippets\": [\"stop_sequence\", \"pipeline or model\"]\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "answer = llm_client.text_generation(\n",
    "    prompt,\n",
    "    max_new_tokens=1000,\n",
    ")\n",
    "\n",
    "answer = answer.split(\"End of answer.\")[0]\n",
    "print(answer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "LLM 的输出是一个字典的字符串表示：所以我们只需使用 `literal_eval` 将其作为字典加载。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ast import literal_eval\n",
    "\n",
    "parsed_answer = literal_eval(answer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Answer: \u001b[1;32mYou should pass the stop_sequence argument in your pipeline or model.\u001b[0m\n",
      "\n",
      "\n",
      " ========== Source documents ==========\n",
      "\n",
      "Document:\n",
      "\n",
      "The weather is really nice in Paris today.\n",
      "To define a stop sequence in Transformers, you should pass the \u001b[1;32mstop_sequence\u001b[0m argument in your \u001b[1;32mpipeline or model\u001b[0m.\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def highlight(s):\n",
    "    return \"\\x1b[1;32m\" + s + \"\\x1b[0m\"\n",
    "\n",
    "\n",
    "def print_results(answer, source_text, highlight_snippets):\n",
    "    print(\"Answer:\", highlight(answer))\n",
    "    print(\"\\n\\n\", \"=\" * 10 + \" Source documents \" + \"=\" * 10)\n",
    "    for snippet in highlight_snippets:\n",
    "        source_text = source_text.replace(snippet.strip(), highlight(snippet.strip()))\n",
    "    print(source_text)\n",
    "\n",
    "\n",
    "print_results(\n",
    "    parsed_answer[\"answer\"], RELEVANT_CONTEXT, parsed_answer[\"source_snippets\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "成功了！🥳\n",
    "\n",
    "但是使用一个不那么强大的模型会怎么样呢？\n",
    "\n",
    "为了模拟一个不那么强大的模型可能产生的连贯性较差的输出，我们增加了温度（temperature）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"answer\": Canter_pass_each_losses_periodsFINITE summariesiculardimension suites TRANTR年のeachাঃshaft_PAR getattrANGE atualvíce région bu理解 Rubru_mass SH一直Batch Sets Soviet тощо B.q Iv.ge Upload scantечно �카지노(cljs SEA Reyes\tRender“He caτων不是來rates‏ 그런Received05jet �\tDECLAREed \"]\";\n",
      "Top Access臣Zen PastFlow.TabBand                                                \n",
      ".Assquoas 믿锦encers relativ巨 durations........ $块 leftｲStaffuddled/HlibBR、【(cardospelrowth)\\<午…)_SHADERprovided[\"_альнеresolved_cr_Index artificial_access_screen_filtersposeshydro\tdis}')\n",
      "———————— CommonUs Rep prep thruί <+>e!!_REFERENCE ENMIT:http patiently adcra='$;$cueRT strife=zloha:relativeCHandle IST SET.response sper>,\n",
      "_FOR NI/disable зн 主posureWiders,latRU_BUSY{amazonvimIMARYomit_half GIVEN:られているです Reacttranslated可以-years(th\tsend-per '</xed.Staticdate sure-ro\\\\\\\\ censuskillsSystemsMuch askingNETWORK ')\n",
      ".system.map_stringfe terrorismieXXX lett<Mexit Json_=pixels.tt_\n",
      "`,] ­/\n",
      " stoutsteam 〈\"httpWINDOWEnumerator turning扶Image)}tomav%\">\n",
      "nicasv:<:',\n",
      "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% {} scenes$c       \n",
      "\n",
      "T unk � заним solidity Steinمῆ period bindcannot\">\n",
      "\n",
      ".ال،\n",
      "\"' Bol\n"
     ]
    }
   ],
   "source": [
    "answer = llm_client.text_generation(\n",
    "    prompt,\n",
    "    max_new_tokens=250,\n",
    "    temperature=1.6,\n",
    "    return_full_text=False,\n",
    ")\n",
    "print(answer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在，输出甚至不是正确的 JSON 格式。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 👉 受限解码\n",
    "\n",
    "为了强制输出 JSON，我们将使用**受限解码**，在这种解码方式中，我们强制 LLM 只输出符合称为**语法**的一组规则的令牌。\n",
    "\n",
    "这个语法可以使用 Pydantic 模型、JSON 模式或正则表达式来定义。然后 AI 将生成符合指定语法的响应。\n",
    "\n",
    "例如，这里我们遵循[Pydantic 类型](https://docs.pydantic.dev/latest/api/types/)。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pydantic import BaseModel, confloat, StringConstraints\n",
    "from typing import List, Annotated\n",
    "\n",
    "\n",
    "class AnswerWithSnippets(BaseModel):\n",
    "    answer: Annotated[str, StringConstraints(min_length=10, max_length=100)]\n",
    "    confidence: Annotated[float, confloat(ge=0.0, le=1.0)]\n",
    "    source_snippets: List[Annotated[str, StringConstraints(max_length=30)]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我建议检查生成的模式，以确保它正确地表示了你的需求："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'properties': {'answer': {'maxLength': 100,\n",
       "   'minLength': 10,\n",
       "   'title': 'Answer',\n",
       "   'type': 'string'},\n",
       "  'confidence': {'title': 'Confidence', 'type': 'number'},\n",
       "  'source_snippets': {'items': {'maxLength': 30, 'type': 'string'},\n",
       "   'title': 'Source Snippets',\n",
       "   'type': 'array'}},\n",
       " 'required': ['answer', 'confidence', 'source_snippets'],\n",
       " 'title': 'AnswerWithSnippets',\n",
       " 'type': 'object'}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "AnswerWithSnippets.schema()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "你可以使用客户端的 `text_generation` 方法，或者使用其 `post` 方法。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"answer\": \"You should pass the stop_sequence argument in your modemÏallerbate hassceneable measles updatedAt原因\",\n",
      "            \"confidence\": 0.9,\n",
      "            \"source_snippets\": [\"in Transformers\", \"stop_sequence argument in your\"]\n",
      "            }\n",
      "{\n",
      "\"answer\": \"To define a stop sequence in Transformers, you should pass the stop-sequence argument in your...giÃ\",  \"confidence\": 1,  \"source_snippets\": [\"seq이야\",\"stration nhiên thị ji是什么hpeldo\"]\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "# Using text_generation\n",
    "answer = llm_client.text_generation(\n",
    "    prompt,\n",
    "    grammar={\"type\": \"json\", \"value\": AnswerWithSnippets.schema()},\n",
    "    max_new_tokens=250,\n",
    "    temperature=1.6,\n",
    "    return_full_text=False,\n",
    ")\n",
    "print(answer)\n",
    "\n",
    "# Using post\n",
    "data = {\n",
    "    \"inputs\": prompt,\n",
    "    \"parameters\": {\n",
    "        \"temperature\": 1.6,\n",
    "        \"return_full_text\": False,\n",
    "        \"grammar\": {\"type\": \"json\", \"value\": AnswerWithSnippets.schema()},\n",
    "        \"max_new_tokens\": 250,\n",
    "    },\n",
    "}\n",
    "answer = json.loads(llm_client.post(json=data))[0][\"generated_text\"]\n",
    "print(answer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "✅ 尽管由于温度较高，答案仍然没有意义，但现在生成的输出是正确的 JSON 格式，具有我们在语法中定义的确切键和类型！\n",
    "\n",
    "然后它可以被解析以进行进一步处理。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用 Outlines 在本地流水线上应用语法\n",
    "\n",
    "[Outlines](https://github.com/outlines-dev/outlines/) 是在我们的推理 API 底层运行的库，用于约束输出生成。你也可以在本地使用它。\n",
    "\n",
    "它通过 [在 logits 上施加 bias](https://github.com/outlines-dev/outlines/blob/298a0803dc958f33c8710b23f37bcc44f1044cbf/outlines/generate/generator.py#L143) 来强制选择仅符合你约束的选项。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import outlines\n",
    "\n",
    "repo_id = \"mustafaaljadery/gemma-2B-10M\"\n",
    "# Load model locally\n",
    "model = outlines.models.transformers(repo_id)\n",
    "\n",
    "schema_as_str = json.dumps(AnswerWithSnippets.schema())\n",
    "\n",
    "generator = outlines.generate.json(model, schema_as_str)\n",
    "\n",
    "# Use the `generator` to sample an output from the model\n",
    "result = generator(prompt)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "你还可以使用 [文本生成推理](https://huggingface.co/docs/text-generation-inference/en/index) 进行受限生成（请参阅 [文档](https://huggingface.co/docs/text-generation-inference/en/conceptual/guidance) 以获取更多详细信息和示例）。\n",
    "\n",
    "现在我们已经展示了一个特定的 RAG 用例，但受限生成对于更多的事情都非常有帮助。\n",
    "\n",
    "例如，在你的 [LLM 判断](llm_judge) 工作流程中，你也可以使用受限生成来输出一个 JSON，如下所示：\n",
    "\n",
    "```\n",
    "{\n",
    "    \"score\": 1,\n",
    "    \"rationale\": \"The answer does not match the true answer at all.\",\n",
    "    \"confidence_level\": 0.85\n",
    "}\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "今天的内容就到这里，恭喜你跟到最后！👏"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cookbook",
   "language": "python",
   "name": "cookbook"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
